{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C:\\Users\\user\\Drive\\s-hero\\code\\yolov5_classifier\n"
     ]
    }
   ],
   "source": [
    "%cd C:\\Users\\user\\Drive\\s-hero\\code\\yolov5_classifier\n",
    "!pip install -qr requirements.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pytorch version :  1.9.1\n",
      "device: cuda:0\n",
      "gpus: 1\n",
      "graphic name: NVIDIA GeForce RTX 3080 Ti\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "from pathlib import Path\n",
    "from glob import glob\n",
    "import gc\n",
    "import yaml\n",
    "\n",
    "import matplotlib \n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "import torch \n",
    "from torch.utils.data import Dataset, DataLoader, RandomSampler\n",
    "from torch.utils.data.dataset import random_split\n",
    "\n",
    "from PIL import Image\n",
    "import PIL\n",
    "from tqdm import tqdm\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import transforms as T\n",
    "from pytorch_lightning import LightningModule, LightningDataModule, Trainer, seed_everything\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor\n",
    "from pytorch_lightning.loggers import TensorBoardLogger\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "from torchmetrics.functional import accuracy\n",
    "from pytorch_model_summary import summary as pms\n",
    "from torchinfo import summary as info\n",
    "from torchsummary import summary \n",
    "from yolov5_classifier import *\n",
    "from yolov5_classifier.train import *\n",
    "from yolov5_classifier.val import *\n",
    "from yolov5_classifier.detect import *\n",
    "from yolov5_classifier.export import *\n",
    "from yolov5_classifier.hubconf import *\n",
    "\n",
    "from yolov5_classifier.models import *\n",
    "from yolov5_classifier.models.common import *\n",
    "from yolov5_classifier.models.experimental import *\n",
    "from yolov5_classifier.models.yolo import Model\n",
    "from yolov5_classifier.models.yolo import *\n",
    "from yolov5_classifier.utils import *\n",
    "\n",
    "import random \n",
    "random.seed(7)\n",
    "\n",
    "USE_CUDA = torch.cuda.is_available()\n",
    "device = torch.device('cuda:0' if USE_CUDA else 'cpu')\n",
    "print('pytorch version : ',torch.__version__)\n",
    "print('device:',device)\n",
    "print('gpus:', torch.cuda.device_count())\n",
    "print('graphic name:', torch.cuda.get_device_name())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(r'C:\\Users\\user\\Drive\\s-hero\\Code\\yolov5_classifier\\models\\yolov5m.yaml', 'r') as f:\n",
    "    v5m = yaml.load(f, Loader=yaml.FullLoader)\n",
    "\n",
    "with open(r'C:\\Users\\user\\Drive\\s-hero\\Code\\yolov5\\models\\yolov5m.yaml', 'r') as f:\n",
    "    v5m_org = yaml.load(f, Loader=yaml.FullLoader)\n",
    "\n",
    "org_model = Model(v5m_org, 3, 8)\n",
    "model = Model(v5m, 3, 8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using cache found in C:\\Users\\user/.cache\\torch\\hub\\ultralytics_yolov5_master\n",
      "YOLOv5  0c93333 torch 1.9.1 CUDA:0 (NVIDIA GeForce RTX 3080 Ti, 12288.0MB)\n",
      "\n",
      "Fusing layers... \n",
      "Model Summary: 290 layers, 21172173 parameters, 0 gradients, 49.0 GFLOPs\n"
     ]
    }
   ],
   "source": [
    "model = torch.hub.load('ultralytics/yolov5', 'yolov5m', pretrained=True, autoshape=False)\n",
    "model.model = model.model[:8]\n",
    "m = model.model[-1]  # last layer\n",
    "ch = m.conv.in_channels if hasattr(m, 'conv') else sum([x.in_channels for x in m.m])  # ch into module\n",
    "c = Classify(ch, 8)  # Classify()\n",
    "c.i, c.f, c.type = m.i, m.f, 'models.common.Classify'  # index, from, type\n",
    "model.model[-1] = c  # replace\n",
    "for p in model.parameters():\n",
    "        p.requires_grad = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "===============================================================================================\n",
       "Layer (type:depth-idx)                        Output Shape              Param #\n",
       "===============================================================================================\n",
       "Model                                         --                        --\n",
       "├─Sequential: 1                               --                        --\n",
       "│    └─Conv: 2-1                              [32, 48, 510, 990]        --\n",
       "│    │    └─Conv2d: 3-1                       [32, 48, 510, 990]        5,232\n",
       "│    │    └─SiLU: 3-2                         [32, 48, 510, 990]        --\n",
       "│    └─Conv: 2-2                              [32, 96, 255, 495]        --\n",
       "│    │    └─Conv2d: 3-3                       [32, 96, 255, 495]        41,568\n",
       "│    │    └─SiLU: 3-4                         [32, 96, 255, 495]        --\n",
       "│    └─C3: 2-3                                [32, 96, 255, 495]        --\n",
       "│    │    └─Conv: 3-5                         [32, 48, 255, 495]        4,656\n",
       "│    │    └─Sequential: 3-6                   [32, 48, 255, 495]        46,272\n",
       "│    │    └─Conv: 3-7                         [32, 48, 255, 495]        4,656\n",
       "│    │    └─Conv: 3-8                         [32, 96, 255, 495]        9,312\n",
       "│    └─Conv: 2-4                              [32, 192, 128, 248]       --\n",
       "│    │    └─Conv2d: 3-9                       [32, 192, 128, 248]       166,080\n",
       "│    │    └─SiLU: 3-10                        [32, 192, 128, 248]       --\n",
       "│    └─C3: 2-5                                [32, 192, 128, 248]       --\n",
       "│    │    └─Conv: 3-11                        [32, 96, 128, 248]        18,528\n",
       "│    │    └─Sequential: 3-12                  [32, 96, 128, 248]        369,408\n",
       "│    │    └─Conv: 3-13                        [32, 96, 128, 248]        18,528\n",
       "│    │    └─Conv: 3-14                        [32, 192, 128, 248]       37,056\n",
       "│    └─Conv: 2-6                              [32, 384, 64, 124]        --\n",
       "│    │    └─Conv2d: 3-15                      [32, 384, 64, 124]        663,936\n",
       "│    │    └─SiLU: 3-16                        [32, 384, 64, 124]        --\n",
       "│    └─C3: 2-7                                [32, 384, 64, 124]        --\n",
       "│    │    └─Conv: 3-17                        [32, 192, 64, 124]        73,920\n",
       "│    │    └─Sequential: 3-18                  [32, 192, 64, 124]        2,214,144\n",
       "│    │    └─Conv: 3-19                        [32, 192, 64, 124]        73,920\n",
       "│    │    └─Conv: 3-20                        [32, 384, 64, 124]        147,840\n",
       "│    └─Classify: 2-8                          [32, 8]                   --\n",
       "│    │    └─AdaptiveAvgPool2d: 3-21           [32, 384, 1, 1]           --\n",
       "│    │    └─Conv2d: 3-22                      [32, 8, 1, 1]             3,080\n",
       "│    │    └─Flatten: 3-23                     [32, 8]                   --\n",
       "===============================================================================================\n",
       "Total params: 3,898,136\n",
       "Trainable params: 3,898,136\n",
       "Non-trainable params: 0\n",
       "Total mult-adds (T): 1.94\n",
       "===============================================================================================\n",
       "Input size (MB): 775.53\n",
       "Forward/backward pass size (MB): 39657.97\n",
       "Params size (MB): 15.59\n",
       "Estimated Total Size (MB): 40449.09\n",
       "==============================================================================================="
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#info(model, (32, 3, 1020, 1980), device='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "#}ckpt = torch.load(r'C:\\Users\\user\\Drive\\s-hero\\Code\\yolov5\\runs\\train\\m_b32_300epoch_evo\\weights\\best.pt')\n",
    "ckpt = torch.load(r'C:\\Users\\user\\Drive\\s-hero\\Code\\yolov5_classifier\\runs\\train\\m_b16_300epoch\\weights\\best.pt')\n",
    "weight = ckpt['model'].state_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "odict_keys(['model.0.conv.weight', 'model.0.conv.bias', 'model.1.conv.weight', 'model.1.conv.bias', 'model.2.cv1.conv.weight', 'model.2.cv1.conv.bias', 'model.2.cv2.conv.weight', 'model.2.cv2.conv.bias', 'model.2.cv3.conv.weight', 'model.2.cv3.conv.bias', 'model.2.m.0.cv1.conv.weight', 'model.2.m.0.cv1.conv.bias', 'model.2.m.0.cv2.conv.weight', 'model.2.m.0.cv2.conv.bias', 'model.2.m.1.cv1.conv.weight', 'model.2.m.1.cv1.conv.bias', 'model.2.m.1.cv2.conv.weight', 'model.2.m.1.cv2.conv.bias', 'model.3.conv.weight', 'model.3.conv.bias', 'model.4.cv1.conv.weight', 'model.4.cv1.conv.bias', 'model.4.cv2.conv.weight', 'model.4.cv2.conv.bias', 'model.4.cv3.conv.weight', 'model.4.cv3.conv.bias', 'model.4.m.0.cv1.conv.weight', 'model.4.m.0.cv1.conv.bias', 'model.4.m.0.cv2.conv.weight', 'model.4.m.0.cv2.conv.bias', 'model.4.m.1.cv1.conv.weight', 'model.4.m.1.cv1.conv.bias', 'model.4.m.1.cv2.conv.weight', 'model.4.m.1.cv2.conv.bias', 'model.4.m.2.cv1.conv.weight', 'model.4.m.2.cv1.conv.bias', 'model.4.m.2.cv2.conv.weight', 'model.4.m.2.cv2.conv.bias', 'model.4.m.3.cv1.conv.weight', 'model.4.m.3.cv1.conv.bias', 'model.4.m.3.cv2.conv.weight', 'model.4.m.3.cv2.conv.bias', 'model.5.conv.weight', 'model.5.conv.bias', 'model.6.cv1.conv.weight', 'model.6.cv1.conv.bias', 'model.6.cv2.conv.weight', 'model.6.cv2.conv.bias', 'model.6.cv3.conv.weight', 'model.6.cv3.conv.bias', 'model.6.m.0.cv1.conv.weight', 'model.6.m.0.cv1.conv.bias', 'model.6.m.0.cv2.conv.weight', 'model.6.m.0.cv2.conv.bias', 'model.6.m.1.cv1.conv.weight', 'model.6.m.1.cv1.conv.bias', 'model.6.m.1.cv2.conv.weight', 'model.6.m.1.cv2.conv.bias', 'model.6.m.2.cv1.conv.weight', 'model.6.m.2.cv1.conv.bias', 'model.6.m.2.cv2.conv.weight', 'model.6.m.2.cv2.conv.bias', 'model.6.m.3.cv1.conv.weight', 'model.6.m.3.cv1.conv.bias', 'model.6.m.3.cv2.conv.weight', 'model.6.m.3.cv2.conv.bias', 'model.6.m.4.cv1.conv.weight', 'model.6.m.4.cv1.conv.bias', 'model.6.m.4.cv2.conv.weight', 'model.6.m.4.cv2.conv.bias', 'model.6.m.5.cv1.conv.weight', 'model.6.m.5.cv1.conv.bias', 'model.6.m.5.cv2.conv.weight', 'model.6.m.5.cv2.conv.bias', 'model.7.conv.weight', 'model.7.conv.bias'])"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "weight.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.load_state_dict(weight)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "1f24b2f6551b049df64cb25183e90251dea2c63deab47566e118362740cff14f"
  },
  "kernelspec": {
   "display_name": "Python 3.9.7 64-bit ('torch': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
